import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class NTXentLoss_poly_4D(torch.nn.Module):

    def __init__(self, device, batch_size, temperature, patch, use_cosine_similarity):
        super(NTXentLoss_poly_4D, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.device = device
        self.patch = patch
        self.softmax = torch.nn.Softmax(dim=-1)
        self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool)
        self.similarity_function = self._get_similarity_function(use_cosine_similarity)
        self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")

    def _get_similarity_function(self, use_cosine_similarity):
        if use_cosine_similarity:
            self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
            return self._cosine_simililarity
        else:
            return self._dot_simililarity

    def _get_correlated_mask(self):
        # Boolean (2N, 2N) mask, N = batch_size, that is True for negatives only:
        # the main diagonal (self-pairs) and the +/- batch_size diagonals
        # (positive pairs) are masked out.
        diag = np.eye(2 * self.batch_size)
        l1 = np.eye(2 * self.batch_size, 2 * self.batch_size, k=-self.batch_size)
        l2 = np.eye(2 * self.batch_size, 2 * self.batch_size, k=self.batch_size)
        mask = torch.from_numpy(diag + l1 + l2)
        mask = (1 - mask).type(torch.bool)
        return mask.to(self.device)
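
    # A worked illustration (not from the repo): with batch_size = 2 the mask
    # over the 4 x 4 similarity matrix is
    #     [[False,  True, False,  True],
    #      [ True, False,  True, False],
    #      [False,  True, False,  True],
    #      [ True, False,  True, False]]
    # so each row keeps exactly 2 * batch_size - 2 = 2 negatives.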

    @staticmethod
    def _dot_simililarity(x, y):
        # x: (N, C) -> unsqueeze to (N, 1, C)
        # y: (2N, C) -> transpose and unsqueeze to (1, C, 2N)
        # v: (N, 2N)
        v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
        return v

    def _cosine_simililarity(self, x, y):
        # x shape: (N, 1, C)
        # y shape: (1, 2N, C)
        # v shape: (N, 2N)
        v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
        return v

    def forward(self, zis, zjs):
        # zis, zjs: batch * patch * seq_length (one set of patches per view)
        representations = torch.cat([zjs, zis], dim=0)  # (2 * batch) * patch * seq_length
        # Compare the two views patch by patch, so each patch slice holds a
        # (2 * batch) x (2 * batch) similarity matrix.
        similarity_matrix = torch.empty(2 * self.batch_size, self.patch, 2 * self.batch_size, device=self.device)
        for i in range(self.patch):
            similarity_matrix[:, i, :] = self.similarity_function(representations[:, i, :], representations[:, i, :])
        # Extract the positive-pair scores: in each patch slice, the positives
        # sit on the +/- batch_size diagonals of the similarity matrix.
        positives = torch.empty(2 * self.batch_size, self.patch, 1, device=self.device)
        for i in range(self.patch):
            l_pos = torch.diag(similarity_matrix[:, i, :], self.batch_size)
            r_pos = torch.diag(similarity_matrix[:, i, :], -self.batch_size)
            positives[:, i, :] = torch.cat([l_pos, r_pos]).view(2 * self.batch_size, 1)

        # Compute the negative scores: mask out self-similarities and the
        # positive pairs, leaving 2 * batch_size - 2 negatives per row.
        negatives = torch.empty(2 * self.batch_size, self.patch, 2 * self.batch_size - 2, device=self.device)
        for i in range(self.patch):
            t_neg = similarity_matrix[:, i, :]
            negatives[:, i, :] = t_neg[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1)

        # [positive | negatives] per row: (2 * batch) * patch * (2 * batch - 1)
        logits = torch.cat((positives, negatives), dim=2)
        logits /= self.temperature

        """Criterion has an internal one-hot function. Here, make all positives as 1 while all negatives as 0. """
#         labels = torch.zeros(2 * self.batch_size).to(self.device).long() # 可以计算的原：2*batch,1|后来：2*batch * patch * 1
            ## pay attention to this!!!, 试一下是self.batch_size合适还是self.batch_size * patch合适
        labels = torch.zeros(2 * self.batch_size,2*self.batch_size - 1).to(self.device).long()
#         print("lables: ", labels.size(),'logits: ',logits.size())
        CE = self.criterion(logits, labels)  # 使用crossEntropy写所有的部分，
        
        onehot_label = torch.cat(
            (torch.ones(2 * self.batch_size, self.patch, 1),
             torch.zeros(2 * self.batch_size, self.patch, negatives.shape[-1])),
            dim=-1).to(self.device)
        # Poly-loss term: pt keeps only the softmax mass on the positive entry,
        # averaged over every entry of logits.
        pt = torch.mean(onehot_label * torch.nn.functional.softmax(logits, dim=-1))

        epsilon = self.batch_size
        # Per the original note, 1 is replaced by 1 / self.batch_size because
        # pt is a mean over all entries, not just the positives.
        loss = CE / (2 * self.batch_size) + epsilon * (1 / self.batch_size - pt)

        return loss
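
# A minimal smoke test for NTXentLoss_poly_4D; a sketch only. The batch size,
# patch count, and sequence length below are arbitrary assumptions, not values
# from any training configuration in this repo.
def _demo_ntxent_poly_4d():
    device = torch.device("cpu")
    batch_size, patch, seq_length = 4, 3, 16
    loss_fn = NTXentLoss_poly_4D(device, batch_size, temperature=0.2,
                                 patch=patch, use_cosine_similarity=True)
    zis = torch.randn(batch_size, patch, seq_length)  # one augmented view
    zjs = torch.randn(batch_size, patch, seq_length)  # the other view
    return loss_fn(zis, zjs)
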


class NTXentLoss_poly_2D(torch.nn.Module):

    def __init__(self, device, batch_size, temperature, patch, use_cosine_similarity):
        super(NTXentLoss_poly_2D, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.device = device
        self.patch = patch
        self.softmax = torch.nn.Softmax(dim=-1)
        self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool)
        self.similarity_function = self._get_similarity_function(use_cosine_similarity)
        self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")

    def _get_similarity_function(self, use_cosine_similarity):
        if use_cosine_similarity:
            self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
            return self._cosine_simililarity
        else:
            return self._dot_simililarity

    def _get_correlated_mask(self):
        # Same negative-selection mask as in NTXentLoss_poly_4D: True everywhere
        # except the main diagonal and the +/- batch_size diagonals.
        diag = np.eye(2 * self.batch_size)
        l1 = np.eye(2 * self.batch_size, 2 * self.batch_size, k=-self.batch_size)
        l2 = np.eye(2 * self.batch_size, 2 * self.batch_size, k=self.batch_size)
        mask = torch.from_numpy(diag + l1 + l2)
        mask = (1 - mask).type(torch.bool)
        return mask.to(self.device)

    @staticmethod
    def _dot_simililarity(x, y):
        # x: (N, C) -> unsqueeze to (N, 1, C)
        # y: (2N, C) -> transpose and unsqueeze to (1, C, 2N)
        # v: (N, 2N)
        v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
        return v

    def _cosine_simililarity(self, x, y):
        # x shape: (N, 1, C)
        # y shape: (1, 2N, C)
        # v shape: (N, 2N)
        v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
        return v

    def forward(self, zis, zjs):
        # zis, zjs: batch * feature once singleton dimensions are squeezed away
        zis = torch.squeeze(zis)
        zjs = torch.squeeze(zjs)
        representations = torch.cat([zjs, zis], dim=0)  # (2 * batch) * feature
        similarity_matrix = self.similarity_function(representations, representations)  # (2 * batch) x (2 * batch)
        # Extract the positive-pair scores from the +/- batch_size diagonals.
        l_pos = torch.diag(similarity_matrix, self.batch_size)
        r_pos = torch.diag(similarity_matrix, -self.batch_size)
        positives = torch.cat([l_pos, r_pos]).view(2 * self.batch_size, 1)

        # Extract the negatives: mask out self-similarities and the positive
        # pairs, leaving 2 * batch_size - 2 scores per row.
        negatives = similarity_matrix[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1)

        # [positive | negatives] per row: (2 * batch) * (2 * batch - 1)
        logits = torch.cat((positives, negatives), dim=1)
        logits /= self.temperature

        """Criterion has an internal one-hot function. Here, make all positives as 1 while all negatives as 0. """
#         labels = torch.zeros(2 * self.batch_size).to(self.device).long() # 可以计算的原：2*batch,1|后来：2*batch * patch * 1
## pay attention to this!!!, 试一下是self.batch_size合适还是self.batch_size * patch合适
        labels = torch.zeros(2 * self.batch_size).to(self.device).long()
#         print("lables: ", labels.size(),'logits: ',logits.size())
        CE = self.criterion(logits, labels)  # 使用crossEntropy写所有的部分，
        
        onehot_label = torch.cat(
            (torch.ones(2 * self.batch_size, 1),
             torch.zeros(2 * self.batch_size, negatives.shape[-1])),
            dim=-1).to(self.device)
        # Poly-loss term: pt keeps only the softmax mass on the positive entry,
        # averaged over every entry of logits.
        pt = torch.mean(onehot_label * torch.nn.functional.softmax(logits, dim=-1))

        epsilon = self.batch_size
        # Per the original note, 1 is replaced by 1 / self.batch_size because
        # pt is a mean over all entries, not just the positives.
        loss = CE / (2 * self.batch_size) + epsilon * (1 / self.batch_size - pt)

        return loss
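
# A minimal smoke test for NTXentLoss_poly_2D; a sketch with assumed shapes.
# Each view is batch * feature (any singleton dimensions are removed by the
# squeeze in forward()); the sizes below are illustrative only.
def _demo_ntxent_poly_2d():
    device = torch.device("cpu")
    batch_size, feat_dim = 4, 32
    loss_fn = NTXentLoss_poly_2D(device, batch_size, temperature=0.2,
                                 patch=1, use_cosine_similarity=True)
    zis = torch.randn(batch_size, feat_dim, 1)  # squeezed to batch * feature
    zjs = torch.randn(batch_size, feat_dim, 1)
    return loss_fn(zis, zjs)
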
    

class Triple_loss(torch.nn.Module):
    def __init__(self, device, batch_size, temperature, patch, use_cosine_similarity):
        super(Triple_loss, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.patch = patch
        self.device = device
        self.use_cosine_similarity = use_cosine_similarity

    def forward(self, zis, zjs):
        # zis, zjs: batch * timestamps * feature
        B, T = zis.size(0), zis.size(1)
        if T == 1:
            return zis.new_tensor(0.)
        z = torch.cat([zis, zjs], dim=1)  # B x 2T x C
        sim = torch.matmul(z, z.transpose(1, 2))  # B x 2T x 2T
        # Drop the main diagonal (self-similarity) and left-align the remaining
        # entries, giving 2T - 1 candidate logits per timestamp.
        logits = torch.tril(sim, diagonal=-1)[:, :, :-1]
        logits += torch.triu(sim, diagonal=1)[:, :, 1:]
        logits = -F.log_softmax(logits, dim=-1)

        t = torch.arange(T, device=zis.device)
        # The positive for timestamp t in one view is the same timestamp in the
        # other view: column T + t - 1 (resp. t) after the diagonal shift.
        # TODO (author's note): this line needs to be changed.
        loss = (logits[:, t, T + t - 1].mean() + logits[:, T + t, t].mean()) / 2
        return loss
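
# A minimal smoke test for Triple_loss, which scores each timestamp against all
# other timestamps of both views; a sketch with assumed shapes
# batch * timestamps * feature and illustrative sizes only.
def _demo_triple_loss():
    device = torch.device("cpu")
    batch_size, timestamps, feat_dim = 4, 8, 16
    loss_fn = Triple_loss(device, batch_size, temperature=0.2,
                          patch=1, use_cosine_similarity=True)
    zis = torch.randn(batch_size, timestamps, feat_dim)
    zjs = torch.randn(batch_size, timestamps, feat_dim)
    return loss_fn(zis, zjs)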



      